# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
from functools import partial
import warnings

import tqdm

import jax
from jax import jit, lax, random
from jax.example_libraries import optimizers
import jax.numpy as jnp

from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to
from numpyro.handlers import replay, seed, substitute, trace
from numpyro.infer.util import helpful_support_errors, transform_fn
from numpyro.optim import _NumPyroOptim, optax_to_numpyro
from numpyro.util import find_stack_level

SVIState = namedtuple("SVIState", "optim_state mutable_state rng_key")
"""
A :func:`~collections.namedtuple` consisting of the following fields:
 - **optim_state** - current optimizer's state.
 - **mutable_state** - extra state to store values of `"mutable"` sites.
 - **rng_key** - random number generator seed used for the iteration.
"""


SVIRunResult = namedtuple("SVIRunResult", "params state losses")
"""
A :func:`~collections.namedtuple` consisting of the following fields:
 - **params** - the optimized parameters.
 - **state** - the last :data:`SVIState`.
 - **losses** - the losses collected at every step.
"""


def _make_loss_fn(
    elbo,
    rng_key,
    constrain_fn,
    model,
    guide,
    args,
    kwargs,
    static_kwargs,
    mutable_state=None,
):
    def loss_fn(params):
        params = constrain_fn(params)
        if mutable_state is not None:
            params.update(jax.lax.stop_gradient(mutable_state))
            result = elbo.loss_with_mutable_state(
                rng_key, params, model, guide, *args, **kwargs, **static_kwargs
            )
            return result["loss"], result["mutable_state"]
        else:
            return (
                elbo.loss(
                    rng_key, params, model, guide, *args, **kwargs, **static_kwargs
                ),
                None,
            )

    return loss_fn


class SVI(object):
    """
    Stochastic Variational Inference given an ELBO loss objective.

    An instance exposes :meth:`init`, :meth:`update`, :meth:`stable_update`,
    :meth:`evaluate`, :meth:`get_params` and the convenience driver :meth:`run`.

    **References**

    1. *SVI Part I: An Introduction to Stochastic Variational Inference in Pyro*,
       (http://pyro.ai/examples/svi_part_i.html)

    **Example:**

    .. doctest::

        >>> from jax import random
        >>> import jax.numpy as jnp
        >>> import numpyro
        >>> import numpyro.distributions as dist
        >>> from numpyro.distributions import constraints
        >>> from numpyro.infer import Predictive, SVI, Trace_ELBO

        >>> def model(data):
        ...     f = numpyro.sample("latent_fairness", dist.Beta(10, 10))
        ...     with numpyro.plate("N", data.shape[0] if data is not None else 10):
        ...         numpyro.sample("obs", dist.Bernoulli(f), obs=data)

        >>> def guide(data):
        ...     alpha_q = numpyro.param("alpha_q", 15., constraint=constraints.positive)
        ...     beta_q = numpyro.param("beta_q", lambda rng_key: random.exponential(rng_key),
        ...                            constraint=constraints.positive)
        ...     numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

        >>> data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
        >>> optimizer = numpyro.optim.Adam(step_size=0.0005)
        >>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
        >>> svi_result = svi.run(random.PRNGKey(0), 2000, data)
        >>> params = svi_result.params
        >>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
        >>> # use guide to make predictive
        >>> predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
        >>> samples = predictive(random.PRNGKey(1), data=None)
        >>> # get posterior samples
        >>> predictive = Predictive(guide, params=params, num_samples=1000)
        >>> posterior_samples = predictive(random.PRNGKey(1), data=None)
        >>> # use posterior samples to make predictive
        >>> predictive = Predictive(model, posterior_samples, params=params, num_samples=1000)
        >>> samples = predictive(random.PRNGKey(1), data=None)

    :param model: Python callable with Pyro primitives for the model.
    :param guide: Python callable with Pyro primitives for the guide
        (recognition network).
    :param optim: An instance of :class:`~numpyro.optim._NumPyroOptim`, a
        ``jax.example_libraries.optimizers.Optimizer`` or an Optax
        ``GradientTransformation``. If you pass an Optax optimizer it will
        automatically be wrapped using :func:`numpyro.optim.optax_to_numpyro`.

            >>> from optax import adam, chain, clip
            >>> svi = SVI(model, guide, chain(clip(10.0), adam(1e-3)), loss=Trace_ELBO())

    :param loss: ELBO loss, i.e. negative Evidence Lower Bound, to minimize.
    :param static_kwargs: static arguments for the model / guide, i.e. arguments
        that remain constant during fitting.
    """

    def __init__(self, model, guide, optim, loss, **static_kwargs):
        """
        Store the model, guide and loss, and normalize ``optim`` to a
        :class:`~numpyro.optim._NumPyroOptim`.

        :raises ImportError: if ``optim`` is neither a ``_NumPyroOptim`` nor a
            ``jax.example_libraries.optimizers.Optimizer`` and Optax is not
            installed.
        :raises TypeError: if ``optim`` is none of the supported optimizer types.
        """
        self.model = model
        self.guide = guide
        self.loss = loss
        self.static_kwargs = static_kwargs
        # Populated by `init`; maps optimizer-side params back through the
        # per-site transforms collected there.
        self.constrain_fn = None

        if isinstance(optim, _NumPyroOptim):
            self.optim = optim
        elif isinstance(optim, optimizers.Optimizer):
            # `optim` unpacks into the functions `_NumPyroOptim` expects; the
            # first argument is a factory that hands them back unchanged.
            self.optim = _NumPyroOptim(lambda *args: args, *optim)
        else:
            # Optax is an optional dependency, so it is imported lazily.
            try:
                import optax
            except ImportError as err:
                raise ImportError(
                    "It looks like you tried to use an optimizer that isn't an "
                    "instance of numpyro.optim._NumPyroOptim or "
                    "jax.example_libraries.optimizers.Optimizer. There is experimental "
                    "support for Optax optimizers, but you need to install Optax. "
                    "It can be installed with `pip install optax`."
                ) from err

            if not isinstance(optim, optax.GradientTransformation):
                raise TypeError(
                    "Expected either an instance of numpyro.optim._NumPyroOptim, "
                    "jax.example_libraries.optimizers.Optimizer or "
                    "optax.GradientTransformation. Got {}".format(type(optim))
                )

            self.optim = optax_to_numpyro(optim)

    def init(self, rng_key, *args, init_params=None, **kwargs):
        """
        Gets the initial SVI state.

        As a side effect this sets ``self.constrain_fn`` from the constraints
        declared at the ``param`` sites found while tracing the guide and model.

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param dict init_params: if not None, initialize :class:`numpyro.param` sites with values from
            this dictionary instead of using ``init_value`` in :class:`numpyro.param` primitives.
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: the initial :data:`SVIState`
        :raises ValueError: if the loss declares ``multi_sample_guide`` and the
            model contains a ``mutable`` site.
        """
        # One key is carried in the returned state; the other two seed the
        # model and the guide for this initial trace.
        rng_key, model_seed, guide_seed = random.split(rng_key, 3)
        model_init = seed(self.model, model_seed)
        guide_init = seed(self.guide, guide_seed)
        if init_params is not None:
            guide_init = substitute(guide_init, init_params)
        guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
        init_guide_params = {
            name: site["value"]
            for name, site in guide_trace.items()
            if site["type"] == "param"
        }
        if init_params is not None:
            # User-supplied values take precedence and may also name sites
            # that only exist in the model.
            init_guide_params.update(init_params)
        if getattr(self.loss, "multi_sample_guide", False):
            # Feed the model the first entry of every non-empty guide sample.
            # NOTE(review): presumably the leading axis is the guide's sample
            # dimension -- confirm against the loss implementation.
            latents = {
                name: site["value"][0]
                for name, site in guide_trace.items()
                if site["type"] == "sample" and site["value"].size > 0
            }
            latents.update(init_guide_params)
            with trace() as model_trace, substitute(data=latents):
                model_init(*args, **kwargs, **self.static_kwargs)
            for site in model_trace.values():
                if site["type"] == "mutable":
                    raise ValueError(
                        "mutable state in model is not supported for "
                        "multi-sample guide."
                    )
        else:
            # Replay the guide's sample sites in the model and substitute the
            # initial parameter values.
            model_trace = trace(
                substitute(replay(model_init, guide_trace), init_guide_params)
            ).get_trace(*args, **kwargs, **self.static_kwargs)

        params = {}
        inv_transforms = {}
        mutable_state = {}
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in list(model_trace.values()) + list(guide_trace.values()):
            if site["type"] == "param":
                constraint = site["kwargs"].pop("constraint", constraints.real)
                with helpful_support_errors(site):
                    transform = biject_to(constraint)
                inv_transforms[site["name"]] = transform
                # The optimizer works on the inverse-transformed value.
                params[site["name"]] = transform.inv(site["value"])
            elif site["type"] == "mutable":
                mutable_state[site["name"]] = site["value"]
            elif (
                site["type"] == "sample"
                and (not site["is_observed"])
                and site["fn"].support.is_discrete
                and not self.loss.can_infer_discrete
            ):
                s_name = type(self.loss).__name__
                warnings.warn(
                    f"Currently, SVI with {s_name} loss does not support models with discrete latent variables",
                    stacklevel=find_stack_level(),
                )

        # An empty dict is normalized to None so later steps can test
        # `mutable_state is not None`.
        if not mutable_state:
            mutable_state = None
        self.constrain_fn = partial(transform_fn, inv_transforms)
        # we convert weak types like float to float32/float64
        # to avoid recompiling body_fn in svi.run
        params, mutable_state = jax.tree.map(
            lambda x: lax.convert_element_type(x, jnp.result_type(x)),
            (params, mutable_state),
        )
        return SVIState(self.optim.init(params), mutable_state, rng_key)

    def get_params(self, svi_state):
        """
        Gets values at `param` sites of the `model` and `guide`.

        The optimizer's parameters are read from ``svi_state.optim_state`` and
        passed through ``self.constrain_fn``, which is set by :meth:`init`.

        :param svi_state: current state of SVI.
        :return: the corresponding parameters
        """
        params = self.constrain_fn(self.optim.get_params(svi_state.optim_state))
        return params

    def _step_loss_fn(self, svi_state, args, kwargs):
        """
        Split the state's RNG key and build the loss closure for one step.

        :return: tuple ``(rng_key, loss_fn)`` where ``rng_key`` is the key to
            carry into the next state and ``loss_fn`` is the objective for
            this step (seeded with the other half of the split).
        """
        rng_key, rng_key_step = random.split(svi_state.rng_key)
        loss_fn = _make_loss_fn(
            self.loss,
            rng_key_step,
            self.constrain_fn,
            self.model,
            self.guide,
            args,
            kwargs,
            self.static_kwargs,
            mutable_state=svi_state.mutable_state,
        )
        return rng_key, loss_fn

    def update(self, svi_state, *args, forward_mode_differentiation=False, **kwargs):
        """
        Take a single step of SVI (possibly on a batch / minibatch of data),
        using the optimizer.

        :param svi_state: current state of SVI.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
            Defaults to False.
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: tuple of `(svi_state, loss)`.
        """
        rng_key, loss_fn = self._step_loss_fn(svi_state, args, kwargs)
        (loss_val, mutable_state), optim_state = self.optim.eval_and_update(
            loss_fn,
            svi_state.optim_state,
            forward_mode_differentiation=forward_mode_differentiation,
        )
        return SVIState(optim_state, mutable_state, rng_key), loss_val

    def stable_update(
        self, svi_state, *args, forward_mode_differentiation=False, **kwargs
    ):
        """
        Similar to :meth:`update` but returns the current state if
        the loss or the new state contains invalid values.

        :param svi_state: current state of SVI.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
            Defaults to False.
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: tuple of `(svi_state, loss)`.
        """
        rng_key, loss_fn = self._step_loss_fn(svi_state, args, kwargs)
        (loss_val, mutable_state), optim_state = self.optim.eval_and_stable_update(
            loss_fn,
            svi_state.optim_state,
            forward_mode_differentiation=forward_mode_differentiation,
        )
        return SVIState(optim_state, mutable_state, rng_key), loss_val

    def run(
        self,
        rng_key,
        num_steps,
        *args,
        progress_bar=True,
        stable_update=False,
        forward_mode_differentiation=False,
        init_state=None,
        init_params=None,
        **kwargs,
    ):
        """
        (EXPERIMENTAL INTERFACE) Run SVI with `num_steps` iterations, then return
        the optimized parameters and the stacked losses at every step. If `num_steps`
        is large, setting `progress_bar=False` can make the run faster.

        .. note:: For a complex training process (e.g. the one requires early stopping,
            epoch training, varying args/kwargs,...), we recommend to use the more
            flexible methods :meth:`init`, :meth:`update`, :meth:`evaluate` to
            customize your training procedure.

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param int num_steps: the number of optimization steps.
        :param args: arguments to the model / guide
        :param bool progress_bar: Whether to enable progress bar updates. Defaults to
            ``True``.
        :param bool stable_update: whether to use :meth:`stable_update` to update
            the state. Defaults to False.
        :param bool forward_mode_differentiation: whether to use forward-mode differentiation
            or reverse-mode differentiation. By default, we use reverse mode but the forward
            mode can be useful in some cases to improve the performance. In addition, some
            control flow utility on JAX such as `jax.lax.while_loop` or `jax.lax.fori_loop`
            only supports forward-mode differentiation. See
            `JAX's The Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
            for more information.
        :param SVIState init_state: if not None, begin SVI from the
            final state of previous SVI run. Usage::

                svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
                svi_result = svi.run(random.PRNGKey(0), 2000, data)
                # upon inspection of svi_result the user decides that the model has not converged
                # continue from the end of the previous svi run rather than beginning again from iteration 0
                svi_result = svi.run(random.PRNGKey(1), 2000, data, init_state=svi_result.state)

            If this instance has not been initialized yet, the model and guide
            are traced once (via :meth:`init`) to recover the parameter
            constraints before continuing from ``init_state``.
        :param dict init_params: if not None, initialize :class:`numpyro.param` sites with values from
            this dictionary instead of using ``init_value`` in :class:`numpyro.param` primitives.
            Ignored when ``init_state`` is given.
        :param kwargs: keyword arguments to the model / guide
        :return: a namedtuple with fields `params`, `state` and `losses` where `params`
            holds the optimized values at :class:`numpyro.param` sites,
            `state` is the last :data:`SVIState`,
            and `losses` is the collected loss during the process.
        :rtype: :data:`SVIRunResult`
        :raises ValueError: if ``num_steps`` is less than 1.
        """

        if num_steps < 1:
            raise ValueError("num_steps must be a positive integer.")

        step_fn = self.stable_update if stable_update else self.update

        def body_fn(svi_state, _):
            # Signature matches what `lax.scan` expects: (carry, x) -> (carry, y).
            svi_state, loss = step_fn(
                svi_state,
                *args,
                forward_mode_differentiation=forward_mode_differentiation,
                **kwargs,
            )
            return svi_state, loss

        if init_state is None:
            svi_state = self.init(rng_key, *args, init_params=init_params, **kwargs)
        else:
            if self.constrain_fn is None:
                # Resuming on an instance that never ran `init`: without
                # `constrain_fn` every update would fail with
                # "'NoneType' object is not callable". Trace once to build it;
                # the state returned by `init` is discarded.
                self.init(rng_key, *args, **kwargs)
            svi_state = init_state

        if progress_bar:
            # Wrap once outside the loop instead of re-creating the jitted
            # wrapper on every iteration.
            jitted_body_fn = jit(body_fn)
            losses = []
            # Refresh the postfix roughly 20 times over the whole run.
            batch = max(num_steps // 20, 1)
            with tqdm.trange(1, num_steps + 1) as t:
                for i in t:
                    svi_state, loss = jitted_body_fn(svi_state, None)
                    losses.append(jax.device_get(loss))
                    if i % batch == 0:
                        if stable_update:
                            # `x == x` is False only for NaN; skipped steps
                            # must not poison the running average.
                            valid_losses = [x for x in losses[i - batch :] if x == x]
                            num_valid = len(valid_losses)
                            if num_valid == 0:
                                avg_loss = float("nan")
                            else:
                                avg_loss = sum(valid_losses) / num_valid
                        else:
                            avg_loss = sum(losses[i - batch :]) / batch
                        t.set_postfix_str(
                            "init loss: {:.4f}, avg. loss [{}-{}]: {:.4f}".format(
                                losses[0], i - batch + 1, i, avg_loss
                            ),
                            refresh=False,
                        )
            losses = jnp.stack(losses)
        else:
            svi_state, losses = lax.scan(body_fn, svi_state, None, length=num_steps)

        # XXX: we also return the last svi_state for further inspection of both
        # optimizer's state and mutable state.
        return SVIRunResult(self.get_params(svi_state), svi_state, losses)

    def evaluate(self, svi_state, *args, **kwargs):
        """
        Evaluate the ELBO loss at the given state without updating it
        (possibly on a batch / minibatch of data).

        The loss is computed with the same RNG key :meth:`update` would use for
        its step, and with ``svi_state.mutable_state`` (if any) merged into the
        parameters, mirroring what :meth:`update` does.

        :param svi_state: current state of SVI.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide.
        :return: evaluate ELBO loss given the current parameter values
            (held within `svi_state.optim_state`).
        """
        # we split to have the same seed as `update_fn` given an svi_state
        _, rng_key_eval = random.split(svi_state.rng_key)
        params = self.get_params(svi_state)
        mutable_state = svi_state.mutable_state
        if mutable_state is not None:
            # Same precedence as in `update`: mutable values override
            # same-named params. A new dict avoids touching `params` in place.
            params = {**params, **mutable_state}
        return self.loss.loss(
            rng_key_eval,
            params,
            self.model,
            self.guide,
            *args,
            **kwargs,
            **self.static_kwargs,
        )
